{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e6eec787",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "[[ 2.326157    0.9676092   0.59912694  0.30605093 -2.3376296   3.116142\n",
      "  -1.0918891  -4.7112594   1.0828805   5.3570127   2.5092688  -1.0975481\n",
      "  -1.3621287  -0.2706312   0.75446266]\n",
      " [ 0.8203291   1.1425364  -5.5159254   2.15256    -1.228265    3.6615298\n",
      "  -2.293613    0.10767588 -5.0149612  -1.1832998   1.7172099  -0.45588365\n",
      "   0.40262952  0.7001094   1.0721042 ]\n",
      " [ 3.0388887   3.8081758  -3.1468863  -0.28015733  0.5934015   2.6223903\n",
      "   1.0071808  -2.37889    -1.1004275  -3.1837513  -2.2163894   0.15744041\n",
      "  -1.8371269  -1.4914287  -2.4161725 ]\n",
      " [ 3.6280425  -3.045486   -5.03049    -2.709865   -1.914968   -1.4497113\n",
      "   2.2239273  -0.95654523 -2.3478985  -1.5843691   1.8642083  -6.311548\n",
      "   4.387059   -1.0771842  -1.7222666 ]]\n",
      "<NDArray 4x15 @cpu(0)>\n",
      "epoch 1,loss 76685.250000\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-2-5f25c49afa44>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m    114\u001b[0m \u001b[0mtrain_iter\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mdata_iter\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mx_label\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    115\u001b[0m \u001b[0mtest_iter\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mdata_iter\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0my_test\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0my_label\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 116\u001b[1;33m \u001b[0mtrain_ch3\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnet\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtrain_iter\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtest_iter\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0msquared_loss\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnum_epochs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mparams\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlr\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    117\u001b[0m \u001b[1;31m#set_figsize()\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    118\u001b[0m \u001b[1;31m#plt.scatter(x[:, 1].asnumpy(), x_label[:, 0].asnumpy(), 1);  # 加分号只显示图\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m<ipython-input-2-5f25c49afa44>\u001b[0m in \u001b[0;36mtrain_ch3\u001b[1;34m(net, train_iter, test_iter, loss, num_epochs, batch_size, params, lr)\u001b[0m\n\u001b[0;32m     85\u001b[0m               params=None, lr=None):\n\u001b[0;32m     86\u001b[0m     \u001b[1;32mfor\u001b[0m \u001b[0mepoch\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnum_epochs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 87\u001b[1;33m         \u001b[1;32mfor\u001b[0m \u001b[0mX\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0my\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mdata_iter\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mx_label\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     88\u001b[0m             \u001b[1;32mwith\u001b[0m \u001b[0mautograd\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrecord\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     89\u001b[0m                 \u001b[0my_hat\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnet\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m<ipython-input-2-5f25c49afa44>\u001b[0m in \u001b[0;36mdata_iter\u001b[1;34m(batch_size, features, labels)\u001b[0m\n\u001b[0;32m     61\u001b[0m     \u001b[0mrandom\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mshuffle\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mindices\u001b[0m\u001b[1;33m)\u001b[0m  \u001b[1;31m# 样本的读取顺序是随机的\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     62\u001b[0m     \u001b[1;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnum_examples\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 63\u001b[1;33m         \u001b[0mj\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnd\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0marray\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mindices\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mmin\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mi\u001b[0m \u001b[1;33m+\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnum_examples\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     64\u001b[0m         \u001b[1;32myield\u001b[0m \u001b[0mfeatures\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtake\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mj\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtake\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mj\u001b[0m\u001b[1;33m)\u001b[0m  \u001b[1;31m# take函数根据索引返回对应元素\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     65\u001b[0m \u001b[1;31m# def cross_entropy(y_hat, y):\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\conda\\envs\\gluon\\lib\\site-packages\\mxnet\\ndarray\\utils.py\u001b[0m in \u001b[0;36marray\u001b[1;34m(source_array, ctx, dtype)\u001b[0m\n\u001b[0;32m    144\u001b[0m         \u001b[1;32mreturn\u001b[0m \u001b[0m_sparse_array\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msource_array\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mctx\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mctx\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mdtype\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    145\u001b[0m     \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 146\u001b[1;33m         \u001b[1;32mreturn\u001b[0m \u001b[0m_array\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msource_array\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mctx\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mctx\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mdtype\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    147\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    148\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\conda\\envs\\gluon\\lib\\site-packages\\mxnet\\ndarray\\ndarray.py\u001b[0m in \u001b[0;36marray\u001b[1;34m(source_array, ctx, dtype)\u001b[0m\n\u001b[0;32m   2503\u001b[0m                 \u001b[1;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'source_array must be array like object'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   2504\u001b[0m     \u001b[0marr\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mempty\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msource_array\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mctx\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 2505\u001b[1;33m     \u001b[0marr\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0msource_array\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   2506\u001b[0m     \u001b[1;32mreturn\u001b[0m \u001b[0marr\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   2507\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\conda\\envs\\gluon\\lib\\site-packages\\mxnet\\ndarray\\ndarray.py\u001b[0m in \u001b[0;36m__setitem__\u001b[1;34m(self, key, value)\u001b[0m\n\u001b[0;32m    447\u001b[0m         \u001b[0mindexing_dispatch_code\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_get_indexing_dispatch_code\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    448\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[0mindexing_dispatch_code\u001b[0m \u001b[1;33m==\u001b[0m \u001b[0m_NDARRAY_BASIC_INDEXING\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 449\u001b[1;33m             \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_set_nd_basic_indexing\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    450\u001b[0m         \u001b[1;32melif\u001b[0m \u001b[0mindexing_dispatch_code\u001b[0m \u001b[1;33m==\u001b[0m \u001b[0m_NDARRAY_ADVANCED_INDEXING\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    451\u001b[0m             \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_set_nd_advanced_indexing\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\conda\\envs\\gluon\\lib\\site-packages\\mxnet\\ndarray\\ndarray.py\u001b[0m in \u001b[0;36m_set_nd_basic_indexing\u001b[1;34m(self, key, value)\u001b[0m\n\u001b[0;32m    713\u001b[0m                     \u001b[1;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mgeneric\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0mvalue\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mshape\u001b[0m \u001b[1;33m!=\u001b[0m \u001b[0mshape\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    714\u001b[0m                         \u001b[0mvalue\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbroadcast_to\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mshape\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 715\u001b[1;33m                     \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_sync_copyfrom\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    716\u001b[0m                 \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m  \u001b[1;31m# value might be a list or a tuple\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    717\u001b[0m                     \u001b[0mvalue_nd\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_prepare_value_nd\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mshape\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\conda\\envs\\gluon\\lib\\site-packages\\mxnet\\ndarray\\ndarray.py\u001b[0m in \u001b[0;36m_sync_copyfrom\u001b[1;34m(self, source_array)\u001b[0m\n\u001b[0;32m    879\u001b[0m             \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mhandle\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    880\u001b[0m             \u001b[0msource_array\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mctypes\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdata_as\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mctypes\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mc_void_p\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 881\u001b[1;33m             ctypes.c_size_t(source_array.size)))\n\u001b[0m\u001b[0;32m    882\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    883\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0m_slice\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mstart\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mstop\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "import d2lzh as d2l\n",
    "import xlrd\n",
    "import random\n",
    "import math\n",
    "from IPython import display\n",
    "from matplotlib import pyplot as plt\n",
    "from mxnet import autograd, nd\n",
    "# Mini-batch size and layer widths of the one-hidden-layer network.\n",
    "batch_size = 10\n",
    "num_inputs = 4\n",
    "num_outputs = 1\n",
    "num_hiddens = 15\n",
    "\n",
    "# Hidden layer: weights drawn from a normal distribution with scale 2,\n",
    "# biases start at zero.\n",
    "w = nd.random.normal(scale=2, shape=(num_inputs, num_hiddens))\n",
    "b = nd.zeros(num_hiddens)\n",
    "# Output layer, initialised the same way.\n",
    "w1 = nd.random.normal(scale=2, shape=(num_hiddens, num_outputs))\n",
    "b1 = nd.zeros(num_outputs)\n",
    "\n",
    "# Attach a gradient buffer to every trainable array.\n",
    "params = [w, b, w1, b1]\n",
    "for param in params:\n",
    "    param.attach_grad()\n",
    "\n",
    "print(w)\n",
    "def use_svg_display():\n",
    "    \"\"\"Ask IPython to render matplotlib figures as SVG (vector graphics).\"\"\"\n",
    "    image_format = 'svg'\n",
    "    display.set_matplotlib_formats(image_format)\n",
    "\n",
    "def set_figsize(figsize=(3.5, 2.5)):\n",
    "    \"\"\"Enable SVG rendering and set matplotlib's default figure size.\n",
    "\n",
    "    figsize is stored under rcParams['figure.figsize'].\n",
    "    \"\"\"\n",
    "    use_svg_display()\n",
    "    plt.rcParams.update({'figure.figsize': figsize})\n",
    "\n",
    "def squared_loss(y_hat, y):\n",
    "    \"\"\"Return the element-wise halved squared error, (y_hat - y) ** 2 / 2.\"\"\"\n",
    "    residual = y_hat - y\n",
    "    return residual ** 2 / 2\n",
    "\n",
    "def relu(X):\n",
    "    \"\"\"Element-wise ReLU: the maximum of each entry of X and 0.\"\"\"\n",
    "    floor = 0\n",
    "    return nd.maximum(X, floor)\n",
    "\n",
    "def net(X):\n",
    "    \"\"\"Forward pass: affine map (w, b), then relu, then affine map (w1, b1).\n",
    "\n",
    "    Reads the module-level parameters w, b, w1 and b1.\n",
    "    \"\"\"\n",
    "    hidden = relu(nd.dot(X, w) + b)\n",
    "    return nd.dot(hidden, w1) + b1\n",
    "\n",
    "def excel2matrix(path):\n",
    "    \"\"\"Load the first worksheet of an Excel workbook into an NDArray.\n",
    "\n",
    "    Args:\n",
    "        path: Path of a workbook that xlrd can open.\n",
    "\n",
    "    Returns:\n",
    "        An NDArray of shape (nrows, ncols) holding the cell values of the\n",
    "        first sheet, one sheet row per array row.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If the first sheet has no rows or no columns.\n",
    "    \"\"\"\n",
    "    table = xlrd.open_workbook(path).sheet_by_index(0)\n",
    "    if table.nrows == 0 or table.ncols == 0:\n",
    "        raise ValueError('first sheet of %r is empty' % (path,))\n",
    "    # Gather all rows on the host and convert once: this avoids filling the\n",
    "    # result with throw-away random numbers and copying into it row by row.\n",
    "    rows = [table.row_values(i) for i in range(table.nrows)]\n",
    "    return nd.array(rows)\n",
    " \n",
    "def data_iter(batch_size, features, labels):\n",
    "    \"\"\"Yield (features, labels) mini-batches in a freshly shuffled order.\n",
    "\n",
    "    Every example is visited once; the final batch holds fewer than\n",
    "    batch_size examples when len(features) is not a multiple of batch_size.\n",
    "    \"\"\"\n",
    "    num_examples = len(features)\n",
    "    order = list(range(num_examples))\n",
    "    random.shuffle(order)  # examples are read in random order\n",
    "    for start in range(0, num_examples, batch_size):\n",
    "        stop = min(start + batch_size, num_examples)\n",
    "        batch_idx = nd.array(order[start:stop])\n",
    "        # take() returns the elements at the given indices\n",
    "        yield features.take(batch_idx), labels.take(batch_idx)\n",
    "# def cross_entropy(y_hat, y):\n",
    "#     return -nd.pick(y_hat, y).log()\n",
    "# def accuracy(y_hat, y):\n",
    "#     return (y_hat.argmax(axis=1) == y.astype('float32')).mean().asscalar()\n",
    "\n",
    "# def evaluate_accuracy(data_iter, net):\n",
    "#     acc_sum, n = 0.0, 0\n",
    "#     for X, y in data_iter:\n",
    "#         y = y.astype('float32')\n",
    "#         acc_sum += (net(X).argmax(axis=1) == y).sum().asscalar()\n",
    "#         n += y.size\n",
    "#     return acc_sum / n\n",
    "\n",
    "# Training schedule: number of epochs and the SGD learning rate.\n",
    "num_epochs = 500\n",
    "lr = 0.00001\n",
    "\n",
    "def sgd(params, lr, batch_size):\n",
    "    \"\"\"Apply one mini-batch SGD step to every parameter in place.\n",
    "\n",
    "    Each parameter is overwritten with param - lr * param.grad / batch_size.\n",
    "    \"\"\"\n",
    "    for param in params:\n",
    "        step = lr * param.grad / batch_size\n",
    "        # Slice assignment writes into the existing array instead of rebinding it.\n",
    "        param[:] = param - step\n",
    "\n",
    "def train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size,\n",
    "              params=None, lr=None):\n",
    "    \"\"\"Train net with mini-batch SGD and print the training loss per epoch.\n",
    "\n",
    "    Args:\n",
    "        net: Callable mapping a feature batch to predictions.\n",
    "        train_iter: Not consumed. Batches are re-drawn every epoch from the\n",
    "            module-level x / x_label via data_iter.\n",
    "            NOTE(review): a data_iter generator can only be iterated once,\n",
    "            so it could not serve several epochs anyway.\n",
    "        test_iter: Not used.\n",
    "        loss: Callable (y_hat, y) -> per-example loss.\n",
    "        num_epochs: Number of passes over the training data.\n",
    "        batch_size: Mini-batch size, also used to scale the SGD step.\n",
    "        params: Parameters updated by sgd.\n",
    "        lr: Learning rate passed to sgd.\n",
    "    \"\"\"\n",
    "    for epoch in range(num_epochs):\n",
    "        for X, y in data_iter(batch_size, x, x_label):\n",
    "            with autograd.record():\n",
    "                l = loss(net(X), y)\n",
    "            l.backward()  # compute gradients\n",
    "            sgd(params, lr, batch_size)  # update the parameters\n",
    "        # Mean loss over the whole training set after this epoch. asscalar()\n",
    "        # hands the '%' formatter a plain number, not a 1-element array.\n",
    "        epoch_loss = loss(net(x), x_label).mean().asscalar()\n",
    "        print('epoch %d,loss %f' % (epoch + 1, epoch_loss))\n",
    "\n",
    "\n",
    "# Excel workbooks, given as paths relative to the current working directory.\n",
    "pathX = '1标准化_1_300_307井.xls'  # training features\n",
    "pathX2 = '1_300_307井_label.xls'  # training labels\n",
    "pathX3 = '1标准化_1_300_307井pre.xls'  # features of the second, unlabelled set\n",
    "\n",
    "x = excel2matrix(pathX)\n",
    "x_label = excel2matrix(pathX2)\n",
    "y_test = excel2matrix(pathX3)\n",
    "# Placeholder targets: one zero per row of y_test.\n",
    "y_label = nd.zeros((y_test.shape[0], 1))\n",
    "\n",
    "train_iter = data_iter(batch_size, x, x_label)\n",
    "test_iter = data_iter(batch_size, y_test, y_label)\n",
    "train_ch3(net, train_iter, test_iter, squared_loss, num_epochs, batch_size,\n",
    "          params=params, lr=lr)\n",
    "# set_figsize()\n",
    "# plt.scatter(x[:, 1].asnumpy(), x_label[:, 0].asnumpy(), 1);  # trailing semicolon: show only the figure\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3290baa",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:gluon] *",
   "language": "python",
   "name": "conda-env-gluon-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
